package auth

import (
	"context"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"fmt"
	"gitee.com/lipore/plume/IdGenerator"
	"gitee.com/lipore/plume/errors"
	"gitee.com/lipore/plume/logger"
	"github.com/golang-jwt/jwt"
	"sync"
	"time"
)

// KeyInfo describes a stored ECDSA public key together with its identifier
// and the time after which it is considered expired.
type KeyInfo struct {
	// Key is the public half of a signing key pair.
	Key *ecdsa.PublicKey
	// ExpireAt is the expiry time of this key record.
	ExpireAt time.Time
	// KeyId identifies the key; the encoder writes the same value into the
	// "kid" header of the tokens it signs.
	KeyId string
}

// newKeyInfo allocates a KeyInfo holding the given public key, expiry time
// and key identifier.
func newKeyInfo(key *ecdsa.PublicKey, exp time.Time, kid string) *KeyInfo {
	info := new(KeyInfo)
	info.KeyId = kid
	info.ExpireAt = exp
	info.Key = key
	return info
}

// tokenEncoder signs JWTs with an ECDSA private key that is periodically
// replaced. The embedded mutex is held by refreshSigningKey and encode while
// they touch the key fields.
type tokenEncoder struct {
	sync.Mutex
	// privateKey is the key currently used to sign tokens.
	privateKey *ecdsa.PrivateKey
	// currentKid is the identifier of privateKey; it is written into the
	// "kid" header of every signed token.
	currentKid string
	// maxKeyAge is how long a freshly generated private key is used before
	// it is due for replacement.
	maxKeyAge time.Duration
	// publicKeyStore receives the public key of every newly generated pair.
	publicKeyStore PublicKeyStore
	// privateKeyExpireAt is the time at which the next key refresh is due.
	privateKeyExpireAt time.Time
	// cancelAutoRefreshKey cancels the context of the background refresh
	// loop started by autoRefreshSigningKey; nil until that loop is started.
	cancelAutoRefreshKey func()
}

// tokenEncoderOptions carries the settings accepted by newTokenEncoder.
type tokenEncoderOptions struct {
	// maxKeyAge is copied into tokenEncoder.maxKeyAge.
	maxKeyAge time.Duration
}

// newTokenEncoder creates a tokenEncoder backed by store, generates its first
// signing key and then starts the background key refresh loop bound to ctx.
func newTokenEncoder(ctx context.Context, store PublicKeyStore, options tokenEncoderOptions) *tokenEncoder {
	enc := new(tokenEncoder)
	enc.publicKeyStore = store
	enc.maxKeyAge = options.maxKeyAge

	// The first key must exist before the refresh loop is scheduled, because
	// the loop derives its initial delay from the key's expiry time.
	enc.refreshSigningKey(ctx)
	enc.autoRefreshSigningKey(ctx)
	return enc
}

// autoRefreshSigningKey starts a background goroutine that replaces the
// signing key whenever it is due. Any loop started by a previous call is
// cancelled first. The goroutine exits when ctx is cancelled or when a later
// call to autoRefreshSigningKey replaces it.
//
// The first refresh fires when the current key expires (immediately if it is
// already expired); subsequent refreshes fire five minutes before expiry.
func (encoder *tokenEncoder) autoRefreshSigningKey(ctx context.Context) {
	const (
		// refreshLead is how long before expiry a rotation is scheduled.
		refreshLead = 5 * time.Minute
		// minRefreshDelay prevents the loop from spinning when the computed
		// delay is zero or negative.
		minRefreshDelay = time.Second
	)

	ctxWithCancel, cancel := context.WithCancel(ctx)
	if encoder.cancelAutoRefreshKey != nil {
		encoder.cancelAutoRefreshKey()
	}
	encoder.cancelAutoRefreshKey = cancel

	// refreshSigningKey writes privateKeyExpireAt under the mutex, so it is
	// read under the same mutex here.
	untilExpiry := func() time.Duration {
		encoder.Lock()
		defer encoder.Unlock()
		return time.Until(encoder.privateKeyExpireAt)
	}

	ttl := untilExpiry()
	if ttl < 0 {
		ttl = 0
	}
	timer := time.NewTimer(ttl)
	go func() {
		// Use this loop's own cancel: encoder.cancelAutoRefreshKey may already
		// belong to a newer loop.
		defer cancel()
		defer timer.Stop()
		for {
			select {
			case <-timer.C:
				encoder.refreshSigningKey(ctxWithCancel)
				remaining := untilExpiry()
				next := remaining - refreshLead
				if next <= 0 {
					// Key lifetime is not longer than the lead time: rotate
					// at expiry instead of rotating continuously.
					next = remaining
				}
				if next < minRefreshDelay {
					next = minRefreshDelay
				}
				// The timer has fired and its channel is drained, so Reset is safe.
				timer.Reset(next)
			case <-ctxWithCancel.Done():
				return
			}
		}
	}()
}

// refreshSigningKey generates a new P-256 key pair, stores its public key in
// the public key store and, on success, makes the private key the current
// signing key with a lifetime of maxKeyAge.
//
// If key generation or saving the public key fails, the previous signing key
// (if any) stays in use and the next refresh is postponed by one hour.
// The encoder mutex is held for the whole call.
func (encoder *tokenEncoder) refreshSigningKey(ctx context.Context) {
	encoder.Lock()
	defer encoder.Unlock()
	if encoder.privateKeyExpireAt.Sub(time.Now().Add(5*time.Minute)) > 0 {
		logger.Infof("attend to refresh private key that not expire in %d minutes", 5)
	}

	priKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		// Without a key pair there is nothing to publish; keep the old key.
		logger.Warnf("generate signing key failed: %v", err)
		encoder.privateKeyExpireAt = time.Now().Add(1 * time.Hour)
		return
	}

	keyId := fmt.Sprintf("%d", IdGenerator.Next())
	// The public key record expires later than the private key: its expiry
	// adds config.maxRefreshTokenAge on top of the key age.
	keyInfo := newKeyInfo(&priKey.PublicKey, time.Now().Add(config.maxRefreshTokenAge+encoder.maxKeyAge), keyId)
	if err := encoder.publicKeyStore.SavePublicKey(ctx, keyInfo); err != nil {
		// if save public key failed, fallback to use old key, and just postpone refresh
		logger.Warnf("save public key %s failed: %v", keyId, err)
		encoder.privateKeyExpireAt = time.Now().Add(1 * time.Hour)
		return
	}

	encoder.privateKey = priKey
	encoder.currentKid = keyId
	logger.Debugf("adding new key with keyId: %s", keyId)
	encoder.privateKeyExpireAt = time.Now().Add(encoder.maxKeyAge)
}

// encode signs claims with the current private key using ES256 and returns
// the compact token string. The "kid" header is set to the identifier of the
// key used for signing.
//
// It returns an error if no signing key is available yet (for example when
// the initial key could not be saved) or if signing fails.
func (encoder *tokenEncoder) encode(claims jwt.Claims) (string, error) {
	// Snapshot key and kid together under the lock so they always match, then
	// sign outside the lock so concurrent callers are not serialized.
	encoder.Lock()
	key, kid := encoder.privateKey, encoder.currentKid
	encoder.Unlock()

	if key == nil {
		// Signing with a nil *ecdsa.PrivateKey would dereference a nil pointer.
		return "", errors.New("no signing key available")
	}

	token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
	token.Header["kid"] = kid
	return token.SignedString(key)
}

// tokenDecoder verifies JWTs using public keys looked up from a key store.
type tokenDecoder struct {
	// keyStore is queried with the token's "kid" header to find the
	// verification key.
	keyStore PublicKeyStore
}

// newTokenDecoder builds a tokenDecoder that resolves verification keys
// through keyStore.
func newTokenDecoder(keyStore PublicKeyStore) *tokenDecoder {
	decoder := new(tokenDecoder)
	decoder.keyStore = keyStore
	return decoder
}

// Decode parses and verifies token, filling claims with its payload. The
// verification key is looked up in the key store by the token's "kid" header.
//
// It returns a wrapped error ("parse token failed") when the token cannot be
// parsed or verified, when the "kid" header is missing or not a string, or
// when no public key is found for that kid.
func (decoder *tokenDecoder) Decode(token string, claims jwt.Claims) error {
	_, err := jwt.ParseWithClaims(token, claims, func(parsed *jwt.Token) (interface{}, error) {
		// The header comes from untrusted input: a missing or non-string kid
		// must produce an error, not a panic from a failed type assertion.
		kid, ok := parsed.Header["kid"].(string)
		if !ok {
			err := errors.New("token header has no valid kid")
			logger.Warnf("%v", err)
			return nil, err
		}
		publicKey := decoder.keyStore.LoadPublicKey(context.TODO(), kid)
		if publicKey != nil {
			return publicKey, nil
		}
		err := errors.New("not found effected public key")
		logger.Warnf("%v", err)
		return nil, err
	})
	if err != nil {
		err = errors.WithMessage(err, "parse token failed")
		logger.Warnf("%v", err)
		return err
	}
	return nil
}
